{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A9Dy9CPcPab4"
   },
   "source": [
    "This recipe follows initially the scikit-learn tutorial on working with text data: https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "rVStCLxetO_B"
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "id": "VJCtjWYKPC5U",
    "outputId": "b371edfe-fde5-4a04-8b8a-fd444d921b08"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading 20news dataset. This may take a few minutes.\n",
      "Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)\n"
     ]
    }
   ],
   "source": [
    "categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']\n",
    "twenty_train = fetch_20newsgroups(\n",
    "    subset='train',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42\n",
    "  )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lex3NsBSPSdU"
   },
   "source": [
    "It's a small dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "mB5zYqo6PE9k",
    "outputId": "9cdf1f5b-7520-4ec7-dde7-dc756d4dd482"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2257"
      ]
     },
     "execution_count": 3,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(twenty_train.filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "id": "WedME_XCPRKB",
    "outputId": "7b419e25-a3f9-472e-ab93-ef5a61a5c7ae"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: sd345@city.ac.uk (Michael Collier)\n",
      "Subject: Converting images to HP LaserJet III?\n",
      "Nntp-Posting-Host: hampton\n"
     ]
    }
   ],
   "source": [
    "print(\"\\n\".join(twenty_train.data[0].split(\"\\n\")[:3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mALQhYvMSbrZ"
   },
   "source": [
    "# Using a bag-of-words approach with a classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "Jlp8DsbfRAYY"
   },
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "text_clf = Pipeline([\n",
    "  ('vect', CountVectorizer()),\n",
    "  ('tfidf', TfidfTransformer()),\n",
    "  ('clf', RandomForestClassifier()),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 408
    },
    "id": "DcDpL_6EQYAM",
    "outputId": "54df2e19-41bd-405b-9fb7-3aa1230fbdfc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "         steps=[('vect',\n",
       "                 CountVectorizer(analyzer='word', binary=False,\n",
       "                                 decode_error='strict',\n",
       "                                 dtype=<class 'numpy.int64'>, encoding='utf-8',\n",
       "                                 input='content', lowercase=True, max_df=1.0,\n",
       "                                 max_features=None, min_df=1,\n",
       "                                 ngram_range=(1, 1), preprocessor=None,\n",
       "                                 stop_words=None, strip_accents=None,\n",
       "                                 token_pattern='(?u)\\\\b\\\\w\\\\w+\\\\b',\n",
       "                                 tokenizer=None, vocabulary=Non...\n",
       "                 RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,\n",
       "                                        class_weight=None, criterion='gini',\n",
       "                                        max_depth=None, max_features='auto',\n",
       "                                        max_leaf_nodes=None, max_samples=None,\n",
       "                                        min_impurity_decrease=0.0,\n",
       "                                        min_impurity_split=None,\n",
       "                                        min_samples_leaf=1, min_samples_split=2,\n",
       "                                        min_weight_fraction_leaf=0.0,\n",
       "                                        n_estimators=100, n_jobs=None,\n",
       "                                        oob_score=False, random_state=None,\n",
       "                                        verbose=0, warm_start=False))],\n",
       "         verbose=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text_clf.fit(twenty_train.data, twenty_train.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "YKrPF-VYR_1e",
    "outputId": "e0ae0588-d30d-4b34-b8ee-8c15bb49e6fe"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8009320905459387"
      ]
     },
     "execution_count": 7,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "twenty_test = fetch_20newsgroups(\n",
    "    subset='test',\n",
    "    categories=categories,\n",
    "    shuffle=True,\n",
    "    random_state=42\n",
    ")\n",
    "predicted = text_clf.predict(twenty_test.data)\n",
    "np.mean(predicted == twenty_test.target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AphYch6gSh5i"
   },
   "source": [
    "# Using a word embedding with a classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "WMR8vAXBbhq0",
    "outputId": "69499c36-041a-4ec6-c5f7-36f0a122cd6b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /gdrive\n"
     ]
    }
   ],
   "source": [
    "# if you use this in colab you can store the embeddings in your google drive. \n",
    "# This is optional. \n",
    "# If you don't want to do it, just remove the /gdrive part in the next cells\n",
    "from google.colab import drive\n",
    "drive.mount('/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "id": "BBSlCzAWYmIz",
    "outputId": "dbfb86fa-dd8c-4812-bc69-20464b4540f9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: wget in /usr/local/lib/python3.6/dist-packages (3.2)\n"
     ]
    },
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'/gdrive/My Drive/embeddings/wiki.en.vec'"
      ]
     },
     "execution_count": 23,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "!pip install wget\n",
    "import os\n",
    "import wget\n",
    "filepath = '/gdrive/My Drive/embeddings/wiki.en.vec'\n",
    "if not os.path.isfile(filepath):\n",
    "  wget.download(\n",
    "      'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec',\n",
    "      filepath\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "id": "sIPfOVmoS5tB",
    "outputId": "b356945c-51fd-4fe6-bf73-81ef519d994d"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/smart_open/smart_open_lib.py:252: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n",
      "  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n"
     ]
    }
   ],
   "source": [
    "from gensim.models import KeyedVectors\n",
    "\n",
    "model = KeyedVectors.load_word2vec_format(\n",
    "    filepath,\n",
    "    binary=False, encoding='utf8'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "43F4upWkCFaB",
    "outputId": "73c3cb58-ddd1-4def-e7a0-fd3797bd0618"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
      "  import sys\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1, 300)"
      ]
     },
     "execution_count": 26,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tensorflow.keras.preprocessing.text import text_to_word_sequence\n",
    "\n",
    "def embed_text(text: str):\n",
    "  vector_list = [\n",
    "    model.wv[w].reshape(-1, 1) for w in text_to_word_sequence(text)\n",
    "    if w in model.wv\n",
    "  ]\n",
    "  if len(vector_list) > 0:\n",
    "    return np.mean(\n",
    "        np.concatenate(vector_list, axis=1),\n",
    "        axis=1\n",
    "    ).reshape(1, 300)\n",
    "  else:\n",
    "    return np.zeros(shape=(1, 300))\n",
    "\n",
    "\n",
    "embed_text('training run').shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "id": "V3xnBruIOSxb",
    "outputId": "b8229485-b636-4be9-ccf2-6c27210e010a"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
      "  import sys\n"
     ]
    }
   ],
   "source": [
    "train_transformed = np.concatenate(\n",
    "    [embed_text(t) for t in twenty_train.data]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "26NodEUqO_cB",
    "outputId": "66be743e-e982-49a8-bf8c-181c0beaa7db"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2257, 300)"
      ]
     },
     "execution_count": 28,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_transformed.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "id": "JvhJL1PgOd14"
   },
   "outputs": [],
   "source": [
    "rf = RandomForestClassifier().fit(train_transformed, twenty_train.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 71
    },
    "id": "wpU0y4CCO9ZX",
    "outputId": "4ffc4eef-b643-4b96-e286-48e8923c95a1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `wv` (Attribute will be removed in 4.0.0, use self instead).\n",
      "  import sys\n"
     ]
    }
   ],
   "source": [
    "test_transformed = np.concatenate(\n",
    "    [embed_text(t) for t in twenty_test.data]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "DzFbM0RHPc1h",
    "outputId": "463868e0-fdba-4971-d816-551a7032abcc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8608521970705726"
      ]
     },
     "execution_count": 31,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted = rf.predict(test_transformed)\n",
    "np.mean(predicted == twenty_test.target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PY8WpBQ-eWLh"
   },
   "source": [
    "# Keras model with embedding layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "id": "vrieNKV5e3Z8"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers\n",
    "\n",
    "embedding = layers.Embedding(\n",
    "    input_dim=5000, \n",
    "    output_dim=50, \n",
    "    input_length=500\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "MInr1-EifFgQ"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "\n",
    "tokenizer = Tokenizer(num_words=5000)\n",
    "tokenizer.fit_on_texts(twenty_train.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "pEXw_q3yfae0"
   },
   "outputs": [],
   "source": [
    "X_train = tokenizer.texts_to_sequences(twenty_train.data)\n",
    "X_test = tokenizer.texts_to_sequences(twenty_test.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "DIVyjcRAmX0m"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "X_train = pad_sequences(X_train, padding='post', maxlen=500)\n",
    "X_test = pad_sequences(X_test, padding='post', maxlen=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 289
    },
    "id": "jPBjYiageUbQ",
    "outputId": "d13e94a9-48f9-4b82-ae2d-bc8b57872e4f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (None, 500, 50)           250000    \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 25000)             0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 10)                250010    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 4)                 44        \n",
      "=================================================================\n",
      "Total params: 500,054\n",
      "Trainable params: 500,054\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.losses import SparseCategoricalCrossentropy\n",
    "from tensorflow.keras import regularizers\n",
    "\n",
    "model = Sequential()\n",
    "model.add(embedding)\n",
    "model.add(layers.Flatten())\n",
    "model.add(layers.Dense(\n",
    "    10,\n",
    "    activation='relu',\n",
    "    kernel_regularizer=regularizers.l1_l2(l1=1e-5, l2=1e-4)\n",
    "))\n",
    "model.add(layers.Dense(len(categories), activation='softmax'))\n",
    "model.compile(optimizer='adam',\n",
    "              loss=SparseCategoricalCrossentropy(),\n",
    "              metrics=['accuracy'])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 394
    },
    "id": "JVBg4RZTflSJ",
    "outputId": "ff0ac6a9-92a7-435a-8c53-2860261424e8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "71/71 [==============================] - 3s 44ms/step - loss: 1.3417 - accuracy: 0.3345\n",
      "Epoch 2/10\n",
      "71/71 [==============================] - 4s 60ms/step - loss: 0.7579 - accuracy: 0.7461\n",
      "Epoch 3/10\n",
      "71/71 [==============================] - 4s 62ms/step - loss: 0.1984 - accuracy: 0.9730\n",
      "Epoch 4/10\n",
      "71/71 [==============================] - 5s 71ms/step - loss: 0.0950 - accuracy: 0.9969\n",
      "Epoch 5/10\n",
      "71/71 [==============================] - 5s 65ms/step - loss: 0.0707 - accuracy: 0.9987\n",
      "Epoch 6/10\n",
      "71/71 [==============================] - 5s 70ms/step - loss: 0.0583 - accuracy: 0.9996\n",
      "Epoch 7/10\n",
      "71/71 [==============================] - 5s 67ms/step - loss: 0.0514 - accuracy: 1.0000\n",
      "Epoch 8/10\n",
      "71/71 [==============================] - 5s 74ms/step - loss: 0.0463 - accuracy: 1.0000\n",
      "Epoch 9/10\n",
      "71/71 [==============================] - 5s 67ms/step - loss: 0.0422 - accuracy: 1.0000\n",
      "Epoch 10/10\n",
      "71/71 [==============================] - 5s 70ms/step - loss: 0.0388 - accuracy: 1.0000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fcd96558c88>"
      ]
     },
     "execution_count": 37,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X_train, twenty_train.target, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "XJ4HbaIYoWur",
    "outputId": "ef58ecce-57fd-47b0-dcb1-383a04c12318"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8695073235685752"
      ]
     },
     "execution_count": 38,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted = model.predict(X_test).argmax(axis=1)\n",
    "np.mean(predicted == twenty_test.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G1O2rHUKakdE"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Classifying_newsgroups.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
